#!/usr/bin/env python

import fire
import time
import torch
from transformers import LlamaTokenizer, AutoTokenizer
from torch_grammar import GrammarSampler

def main(
    grammar_file="examples/grammar.ebnf",
    tokenizer_name="WizardLM/WizardCoder-15B-V1.0",
    seed=None,
):
    """Benchmark grammar-constrained token selection on random logits.

    Builds a ``GrammarSampler`` for the ``root`` rule of ``grammar_file``,
    picks up to 10 warm-up tokens, then times up to 1000 further tokens and
    prints the mean time per token along with the sampler's own counters.

    Args:
        grammar_file: Path to the grammar definition to load.
        tokenizer_name: Checkpoint name passed to
            ``AutoTokenizer.from_pretrained``.
        seed: Optional torch RNG seed. When ``None`` a fresh seed is drawn
            with ``torch.seed()``; pass the value printed by an earlier run
            to reproduce it.

    Raises:
        AssertionError: Re-raised unchanged (after printing the offending
            token and the text decoded so far) when one is raised while
            processing logits in the timed loop.
    """
    # A single load is enough; the checkpoint is selectable through
    # `tokenizer_name` instead of loading one tokenizer and overwriting it.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    with open(grammar_file, "r", encoding="utf-8") as file:
        input_text = file.read()
    grammar = GrammarSampler(input_text, "root", tokenizer)

    # One sequence in the batch, seeded with BOS when the tokenizer has one.
    ids = [[]]
    if tokenizer.bos_token_id is not None:
        ids[0].append(tokenizer.bos_token_id)
    logits_processor = grammar.logits_processor()

    if seed is None:
        seed = torch.seed()
    else:
        torch.manual_seed(seed)
    print(f"\x1b[3;36mtorch seed: {seed}\x1b[0m")

    vocab_size = len(tokenizer.get_vocab())

    # Warm-up: excluded from the timing below.
    for i in range(10):
        logits = torch.randn((1, vocab_size))
        logits = logits_processor(ids, logits)
        token = torch.argmax(logits).item()
        ids[0].append(token)
        # EOS on the very first step does not end the warm-up.
        if token == tokenizer.eos_token_id and i > 0:
            break
    print(f"\x1b[1mfirst 10 tokens: \x1b[1;35m{tokenizer.decode(ids[0])}\x1b[0m")

    # perf_counter is monotonic and high-resolution, unlike wall-clock time.
    st = time.perf_counter()
    n1000 = 0
    for _ in range(1000):
        if ids[0][-1] == tokenizer.eos_token_id:
            break
        n1000 += 1
        logits = torch.randn((1, vocab_size))
        try:
            logits = logits_processor(ids, logits)
            token = torch.argmax(logits).item()
        except AssertionError:
            # Nothing has been selected in this iteration yet, so the token
            # to report is the most recently appended one.
            last = ids[0][-1]
            print(f"\x1b[0;31m error accepting token {last}: '{tokenizer.decode([last])}' / '{tokenizer.convert_ids_to_tokens(last)}'\x1b[0m")
            print(f"\x1b[0;31m{tokenizer.decode(ids[0])}\x1b[0m")
            raise
        ids[0].append(token)
        if token == tokenizer.eos_token_id:
            break
    et = time.perf_counter() - st

    if n1000 != 0:
        # Mean seconds per token, reported in microseconds. Dividing by the
        # real count keeps the figure right when EOS ends the loop early.
        avg = et / n1000
        print(
            f"\x1b[1;34mµ={avg * 1e6:.0f}µs\x1b[0;1m\tfor post-warmup tokens (n={n1000})\x1b[0m"
        )
        # NOTE(review): `tt` / `nt` look like accumulated seconds and call
        # count kept by the sampler -- confirm against torch_grammar.
        n = grammar.nt
        if n:
            ms = 1000 * (grammar.tt / n)
            print(
                f"\x1b[1;34mµ={ms:.0f}ms\x1b[0;1m\tfor stack acceptance calculation (n={n})\x1b[0m"
            )

# Run the benchmark only when executed as a script; python-fire builds the
# command-line interface from `main`.
if __name__ == "__main__":
    fire.Fire(main)
